Image Generation with AutoEncoders¶
from data import CelebADataset, PARTITIONS
from models import VAE, GAN
from tensorflow.keras.optimizers.legacy import Adam
import plotly.graph_objects as go
from IPython.display import HTML
from utils import plot_history, plot_embeds
import pandas as pd
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
import numpy as np
import plotly.io as pio
pio.renderers.default = 'notebook_connected'
Data preprocessing¶
We followed the official CelebA split to create train (80%), validation (10%), and test (10%) subsets. The script data.py automatically prepares the dataset subfolders used to load and train our models (archive.zip must first be downloaded from Kaggle).
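The routing step performed by data.py can be sketched as follows. This is a hypothetical sketch, not the repository code: it assumes the official list_eval_partition.csv file, where codes 0/1/2 map to the train/val/test subsets.

```python
import csv
import os
import shutil

# Hypothetical sketch of what data.py does: route each image into a
# train/val/test subfolder according to the official CelebA partition file.
PARTITIONS = {0: 'train', 1: 'val', 2: 'test'}

def split_celeba(partition_csv, src_dir, dst_dir):
    with open(partition_csv) as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row
        for filename, code in reader:
            subset = PARTITIONS[int(code)]
            os.makedirs(os.path.join(dst_dir, subset), exist_ok=True)
            shutil.copy(os.path.join(src_dir, filename),
                        os.path.join(dst_dir, subset, filename))
```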
# !python3 data.py
train, val, test = map(CelebADataset, ('train', 'val', 'test'))
print(f'Number of train samples: {len(train)}')
print(f'Number of validation samples: {len(val)}')
print(f'Number of test samples: {len(test)}')
Number of train samples: 162770
Number of validation samples: 19867
Number of test samples: 19962
As detailed in the class notebooks, we used two normalization schemes: the Variational AutoEncoder (see vae.py) takes input images scaled to the range $[0,1]$ (so its final activation is a sigmoid), while the GAN accepts images in the range $[-1,1]$ and its generator ends with a hyperbolic tangent activation. For these exercises, we generated RGB images at a resolution of $128\times 128$.
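The two normalizations can be sketched as the following pairs of functions. This is a minimal sketch of the idea; the actual NORM/DENORM implementations used by the model classes may differ in detail.

```python
import numpy as np

# Sketch of the two normalization schemes: uint8 pixels in [0, 255] are
# mapped to the range each model expects, matching its output activation.
def norm_vae(img):   # sigmoid output -> inputs in [0, 1]
    return img.astype(np.float32) / 255.0

def norm_gan(img):   # tanh output -> inputs in [-1, 1]
    return img.astype(np.float32) / 127.5 - 1.0

def denorm_vae(x):   # inverse mapping back to uint8 pixels
    return np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8)

def denorm_gan(x):
    return np.clip(np.rint((x + 1.0) * 127.5), 0, 255).astype(np.uint8)
```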
Baseline¶
Variational AutoEncoder¶
The Variational AutoEncoder (Pu et al., 2016) is implemented in the script vae.py. The class VAE accepts the following hyperparameters:
- img_size: Image dimensions fed to the VAE. In this case, RGB images of dimension $128\times 128$.
- hidden_size: Dimension of the latent space.
- pool: If strides, double-strided convolutions are used as the pooling method. If dilation, a dilation rate of 2 is used, followed by max-pooling or up-sampling by a factor of 2.
- residual: Whether to use residual blocks.
The default configuration (used in class notebooks) uses a latent space of dimension $d_h = 200$ with double-strides and no residual nor skip connections:
from tensorflow.python.framework.ops import disable_eager_execution
disable_eager_execution()
vae = VAE(CelebADataset.IMG_SIZE, hidden_size=200, pool='strides', residual=False)
The train() method of the class VAE allows fitting the Keras model with the following hyperparameters:
- train: CelebADataset instance with the training set.
- val: CelebADataset instance with the validation set (used to track the losses and FID score).
- test: CelebADataset instance with the evaluation set (used to save the final generated images).
- path: Folder to store all the training results.
- batch_size: Batch size.
- epochs: Number of training epochs. Defaults to 10.
- train_patience: Number of allowed epochs with no training improvement. Defaults to 5.
- val_patience: Number of allowed epochs with no validation improvement. Defaults to 5.
- steps_per_epoch: Number of batches per epoch. Defaults to 1500.
- optimizer: Keras Optimizer. Defaults to Adam with learning rate $\eta=10^{-4}$.
At training and validation time, the Keras API reports the following metrics on the train and validation sets:
- r_loss: Mean squared error between the real and reconstructed image.
- kl_loss: KL divergence loss in the latent space.
- fid: FID score computed from InceptionV3 features. Due to our computational limits, at training time we used a subset of 500 samples to compute each FID score, resizing the images to $256\times 256$.
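The FID metric can be sketched in a few lines of NumPy/SciPy. This is a minimal version assuming the InceptionV3 pooling features have already been extracted for both sets; the training-time implementation in vae.py may differ.

```python
import numpy as np
from scipy.linalg import sqrtm

def fid_score(real_feats, fake_feats):
    """Frechet Inception Distance between two feature sets of shape
    (n_samples, n_features), e.g. InceptionV3 pooling activations."""
    mu_r, mu_f = real_feats.mean(axis=0), fake_feats.mean(axis=0)
    sigma_r = np.cov(real_feats, rowvar=False)
    sigma_f = np.cov(fake_feats, rowvar=False)
    covmean = sqrtm(sigma_r @ sigma_f)
    if np.iscomplexobj(covmean):   # numerical noise can yield tiny
        covmean = covmean.real     # imaginary parts; discard them
    diff = mu_r - mu_f
    return diff @ diff + np.trace(sigma_r + sigma_f - 2.0 * covmean)
```

Identical feature distributions give a score near zero, and the score grows as the generated statistics drift from the real ones.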
vae_history = vae.train(train, val, test, 'results/vae/', optimizer=Adam(5e-4), epochs=2, steps_per_epoch=100, batch_size=10)
Found 162770 files belonging to 1 classes. Found 19867 files belonging to 1 classes.
Train on 100 steps, validate on 100 steps Epoch 1/2 95/100 [===========================>..] - ETA: 0s - batch: 47.0000 - size: 1.0000 - loss: 78.0602 - r_loss: 0.0745 - kl_loss: 3.5351
100/100 [==============================] - 17s 152ms/step - batch: 49.5000 - size: 1.0000 - loss: 77.1376 - r_loss: 0.0734 - kl_loss: 3.7442 - val_loss: 61.6159 - val_r_loss: 0.0531 - val_kl_loss: 8.5630 - fid: 384.2911 - val_fid: 384.5622 Epoch 2/2 100/100 [==============================] - ETA: 0s - batch: 49.5000 - size: 1.0000 - loss: 57.7901 - r_loss: 0.0493 - kl_loss: 8.4854
100/100 [==============================] - 12s 120ms/step - batch: 49.5000 - size: 1.0000 - loss: 57.7901 - r_loss: 0.0493 - kl_loss: 8.4854 - val_loss: 58.0304 - val_r_loss: 0.0495 - val_kl_loss: 8.5449 - fid: 365.6741 - val_fid: 372.4235
Once training has ended, the folder passed as path will contain the following elements:
results/
vae/
model.h5
history.pkl
results.pkl
epoch-preds/
...
test-preds/
...
- model.h5: The best model weights in terms of FID score in the validation set.
- history.pkl: The training history (losses and metrics).
- results.pkl: The final FID score of each set.
- epoch-preds: The images generated from the validation set during training (by default, 100 images are saved per epoch).
- test-preds: The images generated from the test set at the end of the training process, with the best weight configuration.
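These artifacts can be reloaded outside the training session. A minimal sketch (load_run is a hypothetical helper; the exact dictionary keys depend on what vae.py pickles):

```python
import pickle

# Hypothetical reload of the artifacts written to results/vae/
# (the exact contents depend on how vae.py serializes them).
def load_run(path):
    with open(f'{path}/history.pkl', 'rb') as f:
        history = pickle.load(f)   # per-epoch losses and metrics
    with open(f'{path}/results.pkl', 'rb') as f:
        results = pickle.load(f)   # final FID score of each set
    return history, results
```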
The function plot_history() (see utils.py) displays the progress of a given measure during the training stage.
plot_history(vae_history, 'loss').show()
plot_history(vae_history, 'kl_loss').show()
plot_history(vae_history, 'r_loss').show()
plot_history(vae_history, 'fid').show()
Finally, we can use the display() function (see utils.py) to display the images generated by the VAE.
real = next(test.stream(vae.NORM, 5))[1]
fake = vae.model.predict(real, verbose=0)
display(*map(vae.DENORM, (real, fake)))
WGAN-GP¶
Similarly, we used the implementation from the class notebooks to create a baseline with the WGAN-GP (Gulrajani et al., 2017). The class GAN (implemented in gan.py) builds and trains the adversarial model with the corresponding callbacks and metrics. It accepts the following arguments:
- img_size: Image size. For these exercises, we selected RGB images with resolution $128\times 128$.
- hidden_size: Dimension of the latent space.
- pool: Type of pooling to use in the generator and discriminator. If strides, convolutions with double strides are used. If dilation, dilated convolutions followed by max-pooling or up-sampling. Defaults to strides.
- residual: Whether to use residual blocks. By default, standard convolutional blocks are used.
- critic_steps: Number of discriminator training steps before each generator update.
- gp_weight: Weight of the gradient penalty term.
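The gradient penalty weighted by gp_weight evaluates the critic's gradient at random interpolations between real and fake samples and penalizes its norm for deviating from 1. A NumPy sketch of the term (critic_grad is an assumed callable standing in for the tf.GradientTape computation in gan.py):

```python
import numpy as np

# Sketch of the WGAN-GP gradient penalty. `critic_grad(x)` is assumed to
# return the critic's gradient w.r.t. its input at each sample in x.
def gradient_penalty(real, fake, critic_grad, gp_weight=10.0, rng=None):
    rng = rng or np.random.default_rng()
    # One random interpolation coefficient per sample, broadcast over pixels.
    eps = rng.uniform(size=(real.shape[0],) + (1,) * (real.ndim - 1))
    interp = eps * real + (1.0 - eps) * fake
    grads = critic_grad(interp)
    norms = np.sqrt((grads.reshape(len(grads), -1) ** 2).sum(axis=1))
    return gp_weight * ((norms - 1.0) ** 2).mean()  # penalize ||grad|| != 1
```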
wgan = GAN(CelebADataset.IMG_SIZE, hidden_size = 128, pool='strides', residual=False, critic_steps=3, gp_weight=10)
To train the GAN model, the same arguments as in the VAE implementation are used:
wgan_history = wgan.train(train, val, test, path='results/wgan', batch_size=20, epochs=10)
Found 162770 files belonging to 1 classes. Found 19867 files belonging to 1 classes.
Epoch 1/10
1500/1500 [==============================] - ETA: 0s - c_loss: -2.8608 - c_wass_loss: -3.4819 - c_gp: 0.0621 - g_loss: 8.3531 - c_acc: 15.6801 - g_acc: 95.9855
1500/1500 [==============================] - 531s 351ms/step - c_loss: -2.8608 - c_wass_loss: -3.4819 - c_gp: 0.0621 - g_loss: 8.3531 - c_acc: 15.6705 - g_acc: 95.9881 - val_c_acc: 0.0000e+00 - val_g_acc: 100.0000 - fid: 374.6558 - val_fid: 372.9523 Epoch 2/10 1500/1500 [==============================] - ETA: 0s - c_loss: -0.7196 - c_wass_loss: -0.7840 - c_gp: 0.0064 - g_loss: 5.6114 - c_acc: 0.9301 - g_acc: 99.3868
1500/1500 [==============================] - 530s 353ms/step - c_loss: -0.7196 - c_wass_loss: -0.7840 - c_gp: 0.0064 - g_loss: 5.6114 - c_acc: 0.9319 - g_acc: 99.3869 - val_c_acc: 0.0000e+00 - val_g_acc: 100.0000 - fid: 402.8370 - val_fid: 400.5402 Epoch 3/10 1500/1500 [==============================] - ETA: 0s - c_loss: -1.2087 - c_wass_loss: -1.3187 - c_gp: 0.0110 - g_loss: 3.4947 - c_acc: 15.1255 - g_acc: 91.8309
1500/1500 [==============================] - 528s 352ms/step - c_loss: -1.2087 - c_wass_loss: -1.3187 - c_gp: 0.0110 - g_loss: 3.4947 - c_acc: 15.1356 - g_acc: 91.8189 - val_c_acc: 0.0000e+00 - val_g_acc: 100.0000 - fid: 408.8687 - val_fid: 406.1084 Epoch 4/10 1500/1500 [==============================] - ETA: 0s - c_loss: -0.8362 - c_wass_loss: -0.8864 - c_gp: 0.0050 - g_loss: 2.7144 - c_acc: 29.4165 - g_acc: 79.6204
1500/1500 [==============================] - 525s 350ms/step - c_loss: -0.8362 - c_wass_loss: -0.8864 - c_gp: 0.0050 - g_loss: 2.7144 - c_acc: 29.4149 - g_acc: 79.6254 - val_c_acc: 2.2000 - val_g_acc: 100.0000 - fid: 490.2902 - val_fid: 482.3736 Epoch 5/10 1500/1500 [==============================] - ETA: 0s - c_loss: -0.9129 - c_wass_loss: -0.9669 - c_gp: 0.0054 - g_loss: 1.4458 - c_acc: 42.9939 - g_acc: 72.4017
1500/1500 [==============================] - 531s 354ms/step - c_loss: -0.9129 - c_wass_loss: -0.9669 - c_gp: 0.0054 - g_loss: 1.4458 - c_acc: 43.0167 - g_acc: 72.3857 - val_c_acc: 100.0000 - val_g_acc: 48.6000 - fid: 502.4674 - val_fid: 498.0322 Epoch 6/10 1500/1500 [==============================] - ETA: 0s - c_loss: -0.9461 - c_wass_loss: -1.0052 - c_gp: 0.0059 - g_loss: 0.2246 - c_acc: 66.9077 - g_acc: 54.9823
1500/1500 [==============================] - 534s 356ms/step - c_loss: -0.9461 - c_wass_loss: -1.0052 - c_gp: 0.0059 - g_loss: 0.2246 - c_acc: 66.9147 - g_acc: 54.9793 - val_c_acc: 99.6000 - val_g_acc: 60.4000 - fid: 469.1463 - val_fid: 464.2170 Epoch 7/10 1500/1500 [==============================] - ETA: 0s - c_loss: -0.9770 - c_wass_loss: -1.0419 - c_gp: 0.0065 - g_loss: 0.5904 - c_acc: 65.7503 - g_acc: 70.0368
1500/1500 [==============================] - 535s 356ms/step - c_loss: -0.9770 - c_wass_loss: -1.0419 - c_gp: 0.0065 - g_loss: 0.5904 - c_acc: 65.7523 - g_acc: 70.0353 - val_c_acc: 86.6000 - val_g_acc: 96.8000 - fid: 462.3430 - val_fid: 472.8241 Epoch 8/10 1500/1500 [==============================] - ETA: 0s - c_loss: -1.2337 - c_wass_loss: -1.3293 - c_gp: 0.0096 - g_loss: 1.8328 - c_acc: 37.8011 - g_acc: 83.6471
1500/1500 [==============================] - 534s 356ms/step - c_loss: -1.2337 - c_wass_loss: -1.3293 - c_gp: 0.0096 - g_loss: 1.8328 - c_acc: 37.7920 - g_acc: 83.6517 - val_c_acc: 0.4000 - val_g_acc: 100.0000 - fid: 456.8448 - val_fid: 454.1087 Epoch 9/10 1500/1500 [==============================] - ETA: 0s - c_loss: -1.0129 - c_wass_loss: -1.0739 - c_gp: 0.0061 - g_loss: 2.3282 - c_acc: 23.2813 - g_acc: 91.6023
1500/1500 [==============================] - 528s 352ms/step - c_loss: -1.0129 - c_wass_loss: -1.0739 - c_gp: 0.0061 - g_loss: 2.3282 - c_acc: 23.2863 - g_acc: 91.5991 - val_c_acc: 5.8000 - val_g_acc: 100.0000 - fid: 433.3018 - val_fid: 429.7416 Epoch 10/10 1500/1500 [==============================] - ETA: 0s - c_loss: -1.0229 - c_wass_loss: -1.0896 - c_gp: 0.0067 - g_loss: 1.5077 - c_acc: 37.4816 - g_acc: 90.7917
1500/1500 [==============================] - 526s 351ms/step - c_loss: -1.0229 - c_wass_loss: -1.0896 - c_gp: 0.0067 - g_loss: 1.5077 - c_acc: 37.5001 - g_acc: 90.7891 - val_c_acc: 61.4000 - val_g_acc: 100.0000 - fid: 446.2078 - val_fid: 445.7891
plot_history(wgan_history, name=['c_acc', 'g_acc', 'val_c_acc', 'val_g_acc']).update_xaxes(title_text='acc')
plot_history(wgan_history, name='fid')
real = next(test.stream(wgan.NORM, 5))[1]
fake = wgan.model(real)
display(*map(wgan.DENORM, (real, fake)))
We see that the baseline models do not reach acceptable results in image generation. In the next sections we show our results when varying the architecture and extending the training time.
Modifying the architecture¶
In the previous demonstration of the generative baselines we used a limited latent space ($d_h^\text{(vae)} = 200$ and $d_h^\text{(wgan)}=128$) and limited training time. To test the capacities of the VAE and WGAN-GP models, we ran experiments increasing the dimension of the latent space ($\{128, 256, 512\}$), using residual blocks instead of standard convolutions, and increasing the dilation of the convolutional layers to expand the receptive field of the network. To keep this notebook readable, we do not include the Keras output of each model, only a summary of the performance obtained with each configuration. Similar results can be reproduced following the README.md file attached with the code.
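The effect of dilation on the receptive field can be illustrated with standard receptive-field arithmetic for a stack of stride-1 convolutions. This helper is a sketch for illustration, not code from the repository:

```python
# Receptive field of stacked stride-1 convolutions: dilation inflates the
# effective kernel, so a dilated stack widens the field at the same depth.
def receptive_field(kernel=3, dilations=(1, 1, 1)):
    rf = 1
    for d in dilations:
        eff_k = kernel + (kernel - 1) * (d - 1)  # effective kernel size
        rf += eff_k - 1
    return rf
```

For example, three plain 3x3 layers see a 7-pixel window, while the same depth with dilation rate 2 at every layer sees 13 pixels.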
FID scores obtained with VAE:
| $d_h$ | pool | residual | train | val | test |
|---|---|---|---|---|---|
| 128 | strides | - | | | |
| 256 | strides | - | | | |
| 512 | strides | - | | | |
| 512 | dilation | - | | | |
| 512 | strides | + | | | |
| 512 | dilation | + | | | |
FID scores obtained with WGAN-GP:
| $d_h$ | pool | residual | train | val | test |
|---|---|---|---|---|---|
| 128 | strides | - | | | |
| 256 | strides | - | | | |
| 512 | strides | - | | | |
| 512 | dilation | - | | | |
| 512 | strides | + | | | |
| 512 | dilation | + | | | |
Exploring the latent space¶
To explore the latent space obtained by each model (VAE and GAN), we visualized its latent vectors projected into a 3-dimensional space. To group the projections of the input images, we relied on the annotations provided with the CelebA dataset, which describe characteristics of the person in each image.
The list_attr_celeba.csv file contains binary attributes for each image (such as gender, hairstyle, eyeglasses, age, facial characteristics, etc.). To evaluate the VAE latent space, we can group images into categories (for instance, by hair type) and project their embeddings into a representable space (e.g. 2 or 3 dimensions). For this purpose, we used Linear Discriminant Analysis (LDA) to project the selected embeddings of dimension $d_h^\text{(vae)}$ into a 3-dimensional space, using the category labels as supervision.
info = pd.read_csv('archive/list_attr_celeba.csv', index_col=0)
partition = pd.read_csv('archive/list_eval_partition.csv', index_col=0)
info = info.join(partition)
info.columns
Index(['5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',
'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline',
'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair',
'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick',
'Wearing_Necklace', 'Wearing_Necktie', 'Young', 'partition'],
dtype='object')
vae = VAE(CelebADataset.IMG_SIZE, hidden_size=512, pool='strides', residual=False)
vae.model.load_weights('results/vae/vae_512_strides/model.h5')
categories = ['Bald', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']
LIMIT = 50
images, labels = [], []
for c in categories:
selection = info[info[c] == 1][:LIMIT]
images += [f'archive/{PARTITIONS[partition]}/{file}' for file, partition in zip(selection.index, selection.partition)]
labels += [c for _ in range(LIMIT)]
embed = vae.latent(images)
lda = LDA(n_components=3)
proj = lda.fit_transform(np.stack(embed), np.array(labels))
fig = plot_embeds(proj, categories)
def show(fig: go.Figure, path: str):
    # Save the figure as standalone HTML and return it for notebook display
    fig.write_html(path)
    fig.show()
    return HTML(open(path).read())

show(fig, 'figures/embeds.html')